--[[ Copyright (C) 2018 Google Inc.

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
]]

--[[ Two-alternative forced choice (2AFC) Glass pattern detection task.
In each trial the agent must select which of the two presented patterns is a
structured Glass pattern. The other pattern is unstructured noise (its dot
pairs have random orientations). Glass patterns are shown at a range of
coherence levels: a proportion p of the dots are generated by the Glass
pattern transformation and the remaining 1-p are random.
]]

local game = require 'dmlab.system.game'
local helpers = require 'common.helpers'
local image = require 'dmlab.system.image'
local point_and_click = require 'factories.psychlab.point_and_click'
local psychlab_factory = require 'factories.psychlab.factory'
local psychlab_helpers = require 'factories.psychlab.helpers'
local log = require 'common.log'
local random = require 'common.random'
local set = require 'common.set'
local tensor = require 'dmlab.system.tensor'

-- Trial timing, in frames.
local TIME_TO_FIXATE_CROSS = 1
local FAST_INTER_TRIAL_INTERVAL = 1

-- Screen and background.
local SCREEN_SIZE = {width = 512, height = 512}
local BG_COLOR = 127  -- grey level, used for all three RGB channels

-- Rewards.
local FIXATION_REWARD = 0
local CORRECT_REWARD = 1
local INCORRECT_REWARD = 0

-- Stimulus geometry.
local NATIVE_TARGET_SIZE = 128  -- in native pixels, before resizing
local TARGET_SCREEN_PROPORTION = 0.48  -- in 0-1 coordinates
local NUM_DOTS = 128
-- Fraction of the native target size within which dots are generated.
local AVAILABLE_FRACTION = 0.9

-- Layout, in 0-1 screen coordinates.
local FIXATION_SIZE = 0.1
local FIXATION_COLOR = {255, 0, 0}  -- RGB
local CENTER = {0.5, 0.5}
local LEFT_STIM_LOC = {0.248, CENTER[2]}
local RIGHT_STIM_LOC = {0.752, CENTER[2]}
local BUTTON_SIZE = 0.1

-- Coherence of each staircase level; the staircase advances towards higher
-- indices.
local COHERENCE_VALUES = {1, 1, 0.90, 0.80, 0.70, 0.60, 0.50, 0.30, 0.15, 0.05}

-- Staircase parameters.
local PROBE_PROBABILITY = 0.1
local FRACTION_TO_ADVANCE = 0.75
local FRACTION_TO_REMAIN = 0.5

-- Episode limits.
local MAX_IDLE_STEPS = 500
local TRIALS_PER_EPISODE_CAP = 125

-- Pattern-type flags: '1' enables a pattern type, '0' disables it.
local CONCENTRIC = '1'
local RADIAL = '0'
local HYPERBOLIC = '0'
local SPIRAL = '0'

-- Whether trials may draw mixed (black and white) dot polarity.
local ALLOW_MIXED_POLARITY = false

local factory = {}

function factory.createLevelApi(kwargs)
  -- Fill in defaults for any level parameters the caller did not supply.
  -- `x = x or DEFAULT` is safe here because none of these parameters is
  -- legitimately `false`.
  kwargs.timeToFixateCross = kwargs.timeToFixateCross or TIME_TO_FIXATE_CROSS
  kwargs.fastInterTrialInterval = kwargs.fastInterTrialInterval or
    FAST_INTER_TRIAL_INTERVAL
  kwargs.screenSize = kwargs.screenSize or SCREEN_SIZE
  kwargs.bgColor = kwargs.bgColor or BG_COLOR
  kwargs.fixationReward = kwargs.fixationReward or FIXATION_REWARD
  kwargs.correctReward = kwargs.correctReward or CORRECT_REWARD
  kwargs.incorrectReward = kwargs.incorrectReward or INCORRECT_REWARD
  kwargs.nativeTargetSize = kwargs.nativeTargetSize or NATIVE_TARGET_SIZE
  kwargs.targetScreenProportion = kwargs.targetScreenProportion or
    TARGET_SCREEN_PROPORTION
  kwargs.numDots = kwargs.numDots or NUM_DOTS
  kwargs.availableFraction = kwargs.availableFraction or AVAILABLE_FRACTION
  kwargs.fixationSize = kwargs.fixationSize or FIXATION_SIZE
  kwargs.fixationColor = kwargs.fixationColor or FIXATION_COLOR
  kwargs.leftStimLoc = kwargs.leftStimLoc or LEFT_STIM_LOC
  kwargs.rightStimLoc = kwargs.rightStimLoc or RIGHT_STIM_LOC
  kwargs.buttonSize = kwargs.buttonSize or BUTTON_SIZE
  kwargs.coherenceValues = kwargs.coherenceValues or COHERENCE_VALUES
  kwargs.probeProbability = kwargs.probeProbability or PROBE_PROBABILITY
  kwargs.fractionToAdvance = kwargs.fractionToAdvance or FRACTION_TO_ADVANCE
  kwargs.fractionToRemain = kwargs.fractionToRemain or FRACTION_TO_REMAIN
  kwargs.maxIdleSteps = kwargs.maxIdleSteps or MAX_IDLE_STEPS
  kwargs.trialsPerEpisodeCap = kwargs.trialsPerEpisodeCap or
    TRIALS_PER_EPISODE_CAP

  -- The pattern-type flags and allowMixedPolarity can be overridden by the
  -- caller like every other parameter. They may legitimately be `false`, so
  -- they are defaulted only when nil rather than with `or`.
  local function defaultIfNil(value, default)
    if value == nil then return default end
    return value
  end
  kwargs.concentric = defaultIfNil(kwargs.concentric, CONCENTRIC)
  kwargs.radial = defaultIfNil(kwargs.radial, RADIAL)
  kwargs.hyperbolic = defaultIfNil(kwargs.hyperbolic, HYPERBOLIC)
  kwargs.spiral = defaultIfNil(kwargs.spiral, SPIRAL)
  kwargs.allowMixedPolarity = defaultIfNil(kwargs.allowMixedPolarity,
                                           ALLOW_MIXED_POLARITY)

  -- Converts a pattern-type flag to a boolean.
  -- Accepts nil or false (-> false), true (-> true), or a number or numeric
  -- string such as '0' / '1' (-> true when greater than zero).
  -- Raises a descriptive error for any other value instead of failing on a
  -- nil comparison.
  local function stringOrNilToBool(flag)
    if flag == nil or flag == false then return false end
    if flag == true then return true end
    local value = tonumber(flag)
    if value == nil then
      error('Expected a numeric or boolean flag, got: ' .. tostring(flag), 2)
    end
    return value > 0
  end

  -- Per-pattern-type parameters for getTransformation: `a` multiplies the x
  -- coordinate, `b` multiplies the y coordinate and `theta` is a rotation
  -- angle in radians.
  local TRANSFORMATION_PARAMS = {
      concentric = {a = 1, b = 1, theta = math.pi / 36},
      radial = {a = 1.05, b = 1.05, theta = 0},
      hyperbolic = {a = 0.95, b = 1.05, theta = 0},
      spiral = {a = 1.05, b = 1.05, theta = math.pi / 36},
  }

  -- Draws opt.numDots random dot positions, each uniform over the square
  -- [-opt.size / 2, opt.size / 2] x [-opt.size / 2, opt.size / 2].
  -- Every position is a two-element {y, x} table. The dots are appended to
  -- opt.output when given, otherwise to a new table; that table is returned.
  -- The result is a plain Lua table, not a tensor.
  local function createDotPattern(opt)
    assert(type(opt.size) == 'number')
    assert(type(opt.numDots) == 'number')

    local halfExtent = opt.size / 2
    local out = opt.output or {}
    for _ = 1, opt.numDots do
      local first = random:uniformReal(-halfExtent, halfExtent)
      local second = random:uniformReal(-halfExtent, halfExtent)
      out[#out + 1] = {first, second}
    end
    return out
  end

  -- Straight-line distance between two 2-D points given as {y, x} tables.
  local function euclideanDistance(a, b)
    local dy = a[1] - b[1]
    local dx = a[2] - b[2]
    return math.sqrt(dy * dy + dx * dx)
  end

  -- Returns a new {y, x} point that lies `distance` away from `point` in a
  -- uniformly random direction. `point` itself is not modified.
  local function randomShift(point, distance)
    local angle = random:uniformReal(0, 2 * math.pi)
    local dy = distance * math.sin(angle)
    local dx = distance * math.cos(angle)
    return {point[1] + dy, point[2] + dx}
  end

  -- Applies opt.transformation to every point of opt.positions. The results
  -- are appended to opt.output, or to a new table, and that table is returned.
  -- opt.output may be opt.positions itself: only the points present when the
  -- call starts are transformed.
  -- When opt.randomizeOrientation is true, each result is replaced by a point
  -- at the same distance from the original but in a random direction. This
  -- keeps mean local dot spacing the same for targets and distractors; see
  -- Wilson and Wilkinson 1998 for details.
  local function transformDotPattern(opt)
    assert(type(opt.size) == 'number')
    assert(type(opt.transformation) == 'function',
           "Expected function, got " .. type(opt.transformation))
    assert(type(opt.randomizeOrientation) == 'boolean')

    local source = opt.positions
    local mapPoint = opt.transformation
    local randomize = opt.randomizeOrientation
    local result = opt.output or {}

    local count = #source
    for index = 1, count do
      local original = source[index]
      local moved = mapPoint(original)
      if randomize then
        moved = randomShift(original, euclideanDistance(original, moved))
      end
      result[#result + 1] = moved
    end
    return result
  end

  -- Sets the pixel under each dot of `dotPositions` to 255 in `patternTensor`
  -- and returns the tensor.
  -- Dot coordinates are {y, x} offsets from the pattern centre. They are
  -- rounded and shifted into the tensor's index range, and dots that land
  -- outside the tensor are skipped.
  -- Only the first dimension of the shape is read, so the pattern is assumed
  -- to be square.
  local function drawDots(patternTensor, dotPositions)
    local patternSize = patternTensor:shape()[1]
    local halfPatternSize = patternSize / 2
    for _, point in ipairs(dotPositions) do
      local y = math.ceil(point[1] + 0.5) + halfPatternSize
      local x = math.ceil(point[2] + 0.5) + halfPatternSize
      -- Tensor indices are 1-based, so valid indices run from 1 to
      -- patternSize inclusive. The previous `< patternSize` check silently
      -- dropped dots in the last row and column.
      if y > 0 and x > 0 and y <= patternSize and x <= patternSize then
        patternTensor(y, x):val(255)
      end
    end
    return patternTensor
  end

  -- Renders every dot of `dotPositions` into a new size x size ByteTensor and
  -- returns that tensor.
  local function renderPattern(size, dotPositions)
    local canvas = tensor.ByteTensor(size, size)
    return drawDots(canvas, dotPositions)
  end

  -- Splits `dotPositions` into two halves and renders each half into its own
  -- size x size ByteTensor. Returns {dark = <first half>, light = <second
  -- half>}.
  -- The halves are kept as separate tensors so that the caller can colour
  -- them independently. The split is even; for an odd count the extra dot
  -- goes to the dark half. The previous `ceil(n / 2 + 0.5)` split gave 65/63
  -- for 128 dots.
  local function renderReversePattern(size, dotPositions)
    local numDark = math.ceil(#dotPositions / 2)

    local darkDots, lightDots = {}, {}
    for ii, point in ipairs(dotPositions) do
      if ii <= numDark then
        darkDots[#darkDots + 1] = point
      else
        lightDots[#lightDots + 1] = point
      end
    end

    return {
        dark = drawDots(tensor.ByteTensor(size, size), darkDots),
        light = drawDots(tensor.ByteTensor(size, size), lightDots)
    }
  end

  -- Returns a function mapping a {y, x} point to a new {y, x} point.
  -- The function multiplies x by `a` and y by `b`, then rotates the scaled
  -- (x, y) by `theta` radians: x' = X cos - Y sin, y' = X sin + Y cos.
  local function getTransformation(a, b, theta)
    local sinTheta, cosTheta = math.sin(theta), math.cos(theta)
    return function (point)
      local scaledX = a * point[2]
      local scaledY = b * point[1]
      return {
          scaledX * sinTheta + scaledY * cosTheta,
          scaledX * cosTheta - scaledY * sinTheta
      }
    end
  end

  -- Dot polarities accepted by createGlassPattern.
  local VALID_POLARITY = set.Set{'black', 'white', 'mixed'}

  -- Renders one Glass pattern of opt.numDots dots. Returns an
  -- opt.size x opt.size tensor, or, for polarity 'mixed', the
  -- {dark = ..., light = ...} pair produced by renderReversePattern.
  --
  -- A fraction opt.coherence of the dots are arranged in pairs. Each pair is
  -- a random seed dot plus its image under the transformation given by
  -- opt.a, opt.b and opt.theta; the remaining dots are uniform noise.
  -- With opt.randomizeOrientation = true, each partner dot keeps its
  -- distance from the seed but points in a random direction.
  --
  -- The coherent dot count is rounded to the nearest integer and then down to
  -- whole pairs, so the pattern always contains exactly opt.numDots dots and
  -- the noise count is never negative.
  local function createGlassPattern(opt)
    assert(type(opt.size) == 'number')
    assert(type(opt.numDots) == 'number')
    assert(type(opt.a) == 'number')
    assert(type(opt.b) == 'number')
    assert(type(opt.theta) == 'number')
    assert(type(opt.randomizeOrientation) == 'boolean')
    assert(type(opt.coherence) == 'number')
    assert(opt.coherence >= 0 and opt.coherence <= 1)
    assert(VALID_POLARITY[opt.polarity])

    local numCoherentDots = math.floor(opt.coherence * opt.numDots + 0.5)
    local numPairs = math.floor(numCoherentDots / 2)
    local numNoiseDots = opt.numDots - 2 * numPairs

    local dotAreaSize = opt.size * kwargs.availableFraction
    local transformation = getTransformation(opt.a, opt.b, opt.theta)

    -- Seed dots first; their transformed partners are appended to the same
    -- table.
    local dotPositions = createDotPattern{
        size = dotAreaSize,
        numDots = numPairs
    }
    transformDotPattern{
        output = dotPositions,
        positions = dotPositions,
        transformation = transformation,
        size = dotAreaSize,
        randomizeOrientation = opt.randomizeOrientation
    }

    -- Add the incoherent (noise) dots.
    createDotPattern{
        output = dotPositions,
        size = dotAreaSize,
        numDots = numNoiseDots
    }

    -- Shuffle so that the noise dots are not all at the end of the table.
    -- This matters because renderReversePattern renders the first half of
    -- the table as dark dots and the second half as light dots.
    dotPositions = random:shuffle(dotPositions)

    if opt.polarity == 'mixed' then
      return renderReversePattern(opt.size, dotPositions)
    else
      return renderPattern(opt.size, dotPositions)
    end
  end

  --[[ Glass pattern detection psychlab environment class.
  Calling `env(...)` creates a new instance whose metatable is `env` and runs
  `_init` on it with the same arguments.
  ]]
  local env = {}
  env.__index = env

  local function constructEnv(cls, ...)
    local instance = setmetatable({}, cls)
    instance:_init(...)
    return instance
  end
  setmetatable(env, {__call = constructEnv})

  --[[ Creates a one-dimensional adaptive staircase (a 'class' built from a
  plain table) over the entries of `opt.coherenceValues`.

  The current difficulty level K is stored in `difficultyLevel.coherence`.
  Despite its name, this field holds an INDEX into coherenceValues, not a
  coherence value. The staircase advances from K to K + 1.

  Trials are grouped into test blocks. A block has numTrialsPerTest (set to K)
  'base' trials at level K and numTrialsPerTest 'advance_coherence' trials at
  level K + 1. The two types alternate, in an order shuffled once per block.
  When a block ends:
    * K is incremented if at least fractionToAdvance * numTrialsPerTest
      'advance_coherence' trials were correct;
    * K is decremented (never below 1) if fewer than
      fractionToRemain * numTrialsPerTest 'base' trials were correct.
  Both checks run, so passing one and failing the other usually leaves K
  unchanged.

  Separately, each trial is a 'probe' with probability probeProbability.
  Probes use a uniformly random level in 1..K. They do not affect the scores
  or the block length.

  NOTE(review): K and the advance level are capped at _lastLevel
  (#coherenceValues - 1), so the last entry of coherenceValues is never
  presented. Confirm this is intended.
  ]]
  local function initStaircase(opt)
    assert(type(opt.probeProbability) == 'number')
    assert(type(opt.fractionToAdvance) == 'number')
    assert(type(opt.fractionToRemain) == 'number')
    assert(type(opt.coherenceValues) == 'table')

    local staircase = {
        -- Level index K into coherenceValues (not a coherence value).
        difficultyLevel = {coherence = 1},
        numTrialsPerTest = 1,
        probeProbability = opt.probeProbability,
        fractionToAdvance = opt.fractionToAdvance,
        fractionToRemain = opt.fractionToRemain,
        coherenceValues = opt.coherenceValues,
        -- Indexed by the entries of _testOrder (a shuffle of {1, 2}).
        _testTrialTypes = {'base', 'advance_coherence'},
        -- Highest level index the staircase will use.
        _lastLevel = #opt.coherenceValues - 1
    }

    -- Starts a new test block at the current level K. Sets the block length
    -- to K trials per type, picks the base (K) and advance (K + 1, capped at
    -- _lastLevel) levels, shuffles the alternation order and clears the
    -- scores. Draws from `random`.
    function staircase._resetTestDomain(self)
      self.numTrialsPerTest = self.difficultyLevel.coherence

      local advance_coherence_level
      if self.difficultyLevel.coherence < self._lastLevel then
        advance_coherence_level = self.difficultyLevel.coherence + 1
      else
        advance_coherence_level = self.difficultyLevel.coherence
      end
      self._testDomain = {
          -- Shares the difficultyLevel table, so it always reflects K.
          base = self.difficultyLevel,
          advance_coherence = {coherence = advance_coherence_level},
      }

      self._testOrder = random:shuffle({1, 2})
      self._testIndex = 1
      self._scores = {base = 0, advance_coherence = 0}
    end

    -- Applies the advance/demote rules to K, using the scores of the block
    -- that just ended.
    function staircase._updateLevel(self)
      local scoreToAdvance = self.fractionToAdvance * self.numTrialsPerTest
      local scoreToRemain = self.fractionToRemain * self.numTrialsPerTest
      if self._scores['advance_coherence'] >= scoreToAdvance then
        if self.difficultyLevel.coherence < self._lastLevel then
          self.difficultyLevel.coherence = self.difficultyLevel.coherence + 1
        end
      end
      if self._scores['base'] < scoreToRemain then
        if self.difficultyLevel.coherence > 1 then
          self.difficultyLevel.coherence = self.difficultyLevel.coherence - 1
        end
      end
    end

    -- Returns the next test trial, {coherence = <value>, trialType = 'base' or
    -- 'advance_coherence'}, cycling through _testOrder. Advances _testIndex.
    function staircase._getNextTestTrial(self)
      local orderIndex = (self._testIndex - 1) % #self._testOrder + 1
      local trialType = self._testTrialTypes[self._testOrder[orderIndex]]
      local baseOrAdvanceLevel = self._testDomain[trialType]
      local coherence = self.coherenceValues[baseOrAdvanceLevel.coherence]
      self._testIndex = self._testIndex + 1
      return {coherence = coherence, trialType = trialType}
    end

    -- Returns a probe trial at a uniformly random level in 1..K. Does not
    -- advance _testIndex.
    function staircase._getNextProbeTrial(self)
      local randomCoherenceIndex = random:uniformInt(1,
          self.difficultyLevel.coherence)
      return {
          coherence = self.coherenceValues[randomCoherenceIndex],
          trialType = 'probe'
      }
    end

    -- Returns the next trial description {coherence = ..., trialType = ...}:
    -- a probe with probability probeProbability, otherwise the next test
    -- trial.
    function staircase.getNextTrial(self)
      if random:uniformReal(0, 1) < self.probeProbability then
        return self:_getNextProbeTrial()
      else
        return self:_getNextTestTrial()
      end
    end

    -- 'staircase.step' is called at the end of each trial.
    -- `correct` is added to the score of `trialType`; presumably 1 for a
    -- correct response and 0 otherwise (verify against callers). Probe
    -- results are ignored. Once every test trial of the block has been
    -- handed out, the level is updated and a new block starts.
    function staircase.step(self, trialType, correct)
      if trialType ~= 'probe' then
        self._scores[trialType] = self._scores[trialType] + correct
      end
      if self._testIndex > #self._testOrder * self.numTrialsPerTest then
        self:_updateLevel()
        self:_resetTestDomain()
      end
    end

    staircase:_resetTestDomain()
    return staircase
  end

  -- Largest unsigned 32-bit value; upper bound for random block ids.
  -- Uses `^` because math.pow is deprecated in Lua 5.3 and absent without
  -- compatibility flags.
  local MAX_INT = 2 ^ 32 - 1

  -- Constructor (run by `env(pac, opts)`); init is called at the start of each
  -- episode. It reads the pattern-type flags from `opts`, precomputes pixel
  -- sizes from the screen size, collects the enabled pattern types and builds
  -- the images. Raises an error if no pattern type is enabled.
  function env:_init(pac, opts)
    local screen = opts.screenSize
    self.screenSize = screen
    log.info('opts passed to _init:\n' .. helpers.tostring(opts))

    -- Parse task parameters. This order is also the order of patternTypes.
    local allPatterns = {'concentric', 'radial', 'hyperbolic', 'spiral'}
    for _, patternName in ipairs(allPatterns) do
      self[patternName] = stringOrNilToBool(opts[patternName])
    end

    -- Convert screen proportions into pixel sizes.
    self.sizeInPixels = {
        targetHeight = kwargs.targetScreenProportion * screen.height,
        targetWidth = kwargs.targetScreenProportion * screen.width,
        fixationHeight = kwargs.fixationSize * screen.height,
        fixationWidth = kwargs.fixationSize * screen.width
    }

    local enabled = {}
    self.patternTypes = enabled
    self.coherenceValues = kwargs.coherenceValues
    for _, patternName in ipairs(allPatterns) do
      if self[patternName] then
        log.info('Use ' .. patternName .. ' Glass Patterns')
        enabled[#enabled + 1] = patternName
      end
    end

    if #enabled == 0 then
      error('Must select at least one pattern type, e.g. concentric')
    end

    self._timeoutIfIdle = opts.timeoutIfIdle
    self._stepsSinceInteraction = 0
    self:setupImages()

    -- Handle to the point-and-click API.
    self.pac = pac
  end

  -- Called after _init, once per episode. It seeds the RNG, resets the
  -- screen to the background colour with a fixation cross, applies the
  -- trial cap, and creates a fresh staircase and block id.
  -- Note: the episodeId passed to this function may not be correct if the job
  -- has resumed from a checkpoint after preemption.
  function env:reset(episodeId, seed)
    random:seed(seed)

    local grey = kwargs.bgColor
    local pac = self.pac
    pac:setBackgroundColor{grey, grey, grey}
    pac:clearWidgets()
    psychlab_helpers.addFixation(self, kwargs.fixationSize)

    self._latestParams = {}
    self.currentTrial = {}
    psychlab_helpers.setTrialsPerEpisodeCap(self, kwargs.trialsPerEpisodeCap)

    -- initStaircase draws from `random` (it shuffles the first test order),
    -- so it must stay between seeding and the blockId draw.
    self.staircase = initStaircase{
        probeProbability = kwargs.probeProbability,
        fractionToAdvance = kwargs.fractionToAdvance,
        fractionToRemain = kwargs.fractionToRemain,
        coherenceValues = kwargs.coherenceValues
    }

    -- blockId groups together all rows written during the same episode.
    self.blockId = random:uniformInt(0, MAX_INT)
  end

  -- Creates the image tensors used by the task:
  --   * the fixation cross;
  --   * green (correct hover), red (incorrect hover) and black (initial)
  --     response-button images, each kwargs.buttonSize of the screen per side.
  function env:setupImages()
    local images = {}
    self.images = images

    local screen = self.screenSize
    images.fixation = psychlab_helpers.getFixationImage(screen,
        kwargs.bgColor, kwargs.fixationColor, kwargs.fixationSize)

    local buttonH = kwargs.buttonSize * screen.height
    local buttonW = kwargs.buttonSize * screen.width
    images.greenImage = tensor.ByteTensor(buttonH, buttonW, 3):fill{100, 255, 100}
    images.redImage = tensor.ByteTensor(buttonH, buttonW, 3):fill{255, 100, 100}
    images.blackImage = tensor.ByteTensor(buttonH, buttonW, 3)
  end

  -- Ends the current trial. Records the block id and reaction time, feeds
  -- the result to the staircase and publishes the trial row. It then runs
  -- the shared psychlab end-of-trial handling with inter-trial interval
  -- `delay`.
  function env:finishTrial(delay)
    local trial = self.currentTrial
    self._stepsSinceInteraction = 0

    trial.blockId = self.blockId
    local now = game:episodeTimeSeconds()
    trial.reactionTime = now - self._currentTrialStartTime
    self.staircase:step(trial.trialType, trial.correct)

    psychlab_helpers.publishTrialData(trial, kwargs.schema)
    psychlab_helpers.finishTrialCommon(self, delay, kwargs.fixationSize)
  end

  -- Hover callback for the fixation cross. Acts only once the cross has been
  -- hovered for exactly kwargs.timeToFixateCross frames. It then gives the
  -- fixation reward, removes the cross, starts the reaction-time clock and
  -- shows the stimulus array. Any other hoverTime is ignored.
  function env:fixationCallback(name, mousePos, hoverTime, userData)
    if hoverTime ~= kwargs.timeToFixateCross then return end

    self._stepsSinceInteraction = 0
    local pac = self.pac
    pac:addReward(kwargs.fixationReward)
    pac:removeWidget('fixation')
    pac:removeWidget('center_of_fixation')

    -- Reaction time is measured from trial initiation.
    self._currentTrialStartTime = game:episodeTimeSeconds()
    self.currentTrial.stepCount = 0

    self:addArray()
  end

  -- Hover-end callback for the correct response button. Records the response
  -- as correct, gives kwargs.correctReward and finishes the trial with the
  -- fast inter-trial interval.
  function env:onHoverEndCorrect(name, mousePos, hoverTime, userData)
    local trial = self.currentTrial
    trial.response = name
    trial.correct = 1
    self.pac:addReward(kwargs.correctReward)
    self:finishTrial(kwargs.fastInterTrialInterval)
  end

  -- Hover-end callback for the incorrect response button. Records the
  -- response as incorrect, gives kwargs.incorrectReward and finishes the
  -- trial with the fast inter-trial interval.
  function env:onHoverEndIncorrect(name, mousePos, hoverTime, userData)
    local trial = self.currentTrial
    trial.response = name
    trial.correct = 0
    self.pac:addReward(kwargs.incorrectReward)
    self:finishTrial(kwargs.fastInterTrialInterval)
  end

  -- Hover callback for the correct response button: shows it in green.
  function env:correctResponseCallback(name, mousePos, hoverTime, userData)
    local highlight = self.images.greenImage
    self.pac:updateWidget(name, highlight)
  end

  -- Hover callback for the incorrect response button: shows it in red.
  function env:incorrectResponseCallback(name, mousePos, hoverTime, userData)
    local highlight = self.images.redImage
    self.pac:updateWidget(name, highlight)
  end

  -- Takes a 2-D white-on-black glass pattern and thickens its dots:
  --   1. downscale to 64x64 and binarise (any non-zero pixel becomes 255);
  --   2. upscale to the on-screen target size and binarise again.
  -- Returns the upscaled single-channel tensor of 0/255 values (presumably
  -- shaped {targetHeight, targetWidth, 1}).
  -- Note: this combination of operations was discovered through trial and
  -- error. It produces easy-to-see glass patterns down to 85X84 resolution.
  function env:_expandDots(pattern)
    local dims = pattern:shape()
    assert(#dims == 2)
    local rows, cols = dims[1], dims[2]

    local binarise = function(v) return v > 0 and 255 or 0 end

    local small = image.scale(pattern:reshape({rows, cols, 1}), 64, 64)
    small:apply(binarise)

    local large = image.scale(small, self.sizeInPixels.targetHeight,
                              self.sizeInPixels.targetWidth)
    large:apply(binarise)
    return large
  end

  -- Converts a rendered glass pattern into the {height, width, 3} RGB
  -- ByteTensor shown on screen. Dots are coloured according to
  -- self.currentTrial.polarity:
  --   'mixed': `pattern` is a {dark = ..., light = ...} pair of tensors. Dark
  --            dots become 0 and light dots 2 * kwargs.bgColor (ignoring
  --            overlaps); everything else is kwargs.bgColor.
  --   'black': black (0) dots on a kwargs.bgColor background.
  --   'white': 2 * kwargs.bgColor dots on a kwargs.bgColor background.
  -- NOTE(review): any other polarity leaves finalPattern nil and fails at the
  -- shape() call below. createGlassPattern's VALID_POLARITY assert presumably
  -- prevents that.
  function env:_prepareGlassPattern(pattern)
    local finalPattern
    if self.currentTrial.polarity == 'mixed' then
      local darkPattern = self:_expandDots(pattern.dark)
      local lightPattern = self:_expandDots(pattern.light)
      -- Mask: 1 where there is no dark dot, 0 on dark dots.
      darkPattern:apply(function(v) return 255 - v > 0 and 1 or 0 end)
      -- bgColor on light dots, 0 elsewhere; added on top of the background.
      lightPattern:apply(function(v) return v > 0 and kwargs.bgColor or 0 end)
      -- background * darkMask + lightDots
      finalPattern = tensor.ByteTensor(unpack(lightPattern:shape())):
          fill(kwargs.bgColor):cmul(darkPattern):cadd(lightPattern)
    elseif self.currentTrial.polarity == 'black' then
      -- Black dots on grey background
      finalPattern = self:_expandDots(pattern):apply(
          function(v) return v < 255 and kwargs.bgColor or 0 end)
    elseif self.currentTrial.polarity == 'white' then
      -- White dots on grey background
      finalPattern = self:_expandDots(pattern):apply(
          function(v) return v > 0 and kwargs.bgColor * 2 or kwargs.bgColor end)
    end

    -- Turn {height, width} greyscale tensor into {h, w, 3} rgb
    -- (only the first two dimensions of the shape are used).
    local height, width = unpack(finalPattern:shape())
    local rgb = tensor.ByteTensor(height, width, 3)
    finalPattern:applyIndexed(function(v, index)
          rgb(index[1], index[2]):fill(v)
    end)
    return rgb
  end

  -- Adds `pattern` as an image widget centred on `location`. The widget is
  -- named 'target' when `correct` is true and 'distractor' otherwise, and is
  -- kwargs.targetScreenProportion of the screen on each side.
  function env:_placeOnePattern(pattern, location, correct)
    local side = kwargs.targetScreenProportion
    local widgetName
    if correct then
      widgetName = 'target'
    else
      widgetName = 'distractor'
    end

    self.pac:addWidget{
        name = widgetName,
        image = self:_prepareGlassPattern(pattern),
        pos = psychlab_helpers.getUpperLeftFromCenter(location, side),
        size = {side, side}
    }
  end

  --[[ Generates a Glass pattern and a matched noise pattern and places them
    on screen.
    `patternType` from {'concentric', 'radial', 'hyperbolic', 'spiral'} (a key
    of TRANSFORMATION_PARAMS).
    `coherence` (float) in [0, 1] where 0 is fully incoherent (all noise dots)
    and 1 is fully coherent (all dots paired by the transformation).
    `polarity` from {'black', 'white', 'mixed'}.
    `correctLocation` from {'LEFT', 'RIGHT'}. The Glass pattern (widget
    'target') is placed there and the noise pattern (widget 'distractor') at
    the other location. Any other value raises an error.
    The noise pattern uses the same coherence and transformation but with
    randomizeOrientation = true, so its dot pairs have random orientations.
  ]]
  function env:addStims(patternType, coherence, polarity, correctLocation)
    local targetLoc, distractorLoc
    if correctLocation == 'LEFT' then
      targetLoc, distractorLoc = kwargs.leftStimLoc, kwargs.rightStimLoc
    elseif correctLocation == 'RIGHT' then
      targetLoc, distractorLoc = kwargs.rightStimLoc, kwargs.leftStimLoc
    else
      error("Unknown location: " .. tostring(correctLocation))
    end

    local params = TRANSFORMATION_PARAMS[patternType]
    self._latestParams = params

    local function makePattern(randomizeOrientation)
      return createGlassPattern{
          size = kwargs.nativeTargetSize,
          numDots = kwargs.numDots,
          coherence = coherence,
          a = params.a,
          b = params.b,
          theta = params.theta,
          randomizeOrientation = randomizeOrientation,
          polarity = polarity
      }
    end

    -- Both patterns draw from `random`. Keep the Glass pattern first for
    -- reproducibility.
    local glassPattern = makePattern(false)
    local noisePattern = makePattern(true)

    self:_placeOnePattern(glassPattern, targetLoc, true)
    self:_placeOnePattern(noisePattern, distractorLoc, false)
  end

  -- Adds the two response buttons, initially black, along the bottom of the
  -- screen. The button on the `correctLocation` side ('LEFT' or 'RIGHT')
  -- turns green while hovered and ends the trial as correct when the hover
  -- ends. The other button turns red and ends the trial as incorrect.
  -- Raises an error for any other `correctLocation`.
  function env:addResponseButtons(correctLocation)
    local buttonPosX = 0.5 - kwargs.buttonSize * 1.5
    local buttonSize = {kwargs.buttonSize, kwargs.buttonSize}

    local leftResponseCallback, rightResponseCallback
    local hoverEndRight, hoverEndLeft
    if correctLocation == 'RIGHT' then
      rightResponseCallback = self.correctResponseCallback
      hoverEndRight = self.onHoverEndCorrect
      leftResponseCallback = self.incorrectResponseCallback
      hoverEndLeft = self.onHoverEndIncorrect
    elseif correctLocation == 'LEFT' then
      rightResponseCallback = self.incorrectResponseCallback
      hoverEndRight = self.onHoverEndIncorrect
      leftResponseCallback = self.correctResponseCallback
      hoverEndLeft = self.onHoverEndCorrect
    else
      -- error()'s second argument is a stack level, not part of the message,
      -- so the location must be concatenated into the message itself.
      error("Unknown location: " .. tostring(correctLocation))
    end

    -- The right button mirrors the left one about the vertical centre line.
    self.pac:addWidget{
        name = 'respond_right',
        image = self.images.blackImage,
        pos = {1 - buttonPosX - kwargs.buttonSize, 1 - kwargs.buttonSize},
        size = buttonSize,
        mouseHoverCallback = rightResponseCallback,
        mouseHoverEndCallback = hoverEndRight,
    }
    self.pac:addWidget{
        name = 'respond_left',
        image = self.images.blackImage,
        pos = {buttonPosX, 1 - kwargs.buttonSize},
        size = buttonSize,
        mouseHoverCallback = leftResponseCallback,
        mouseHoverEndCallback = hoverEndLeft,
    }
  end

  -- Starts the stimulus phase of a trial. It picks the pattern type,
  -- polarity, correct side and staircase trial, stores them in
  -- self.currentTrial, then shows the stimuli and the response buttons.
  -- The random draws happen in a fixed order for reproducibility.
  function env:addArray()
    local trial = self.currentTrial
    trial.patternType = random:choice(self.patternTypes)

    local polarities = {'black', 'white'}
    if kwargs.allowMixedPolarity then
      polarities[#polarities + 1] = 'mixed'
    end
    trial.polarity = random:choice(polarities)
    trial.correctLocation = random:choice({'LEFT', 'RIGHT'})

    local nextTrial = self.staircase:getNextTrial()
    trial.coherence = nextTrial.coherence
    trial.trialType = nextTrial.trialType

    self:addStims(trial.patternType, trial.coherence, trial.polarity,
                  trial.correctLocation)
    self:addResponseButtons(trial.correctLocation)
  end

  -- Removes the target and distractor images and both response buttons.
  function env:removeArray()
    local widgetNames = {'target', 'distractor', 'respond_right', 'respond_left'}
    for _, widgetName in ipairs(widgetNames) do
      self.pac:removeWidget(widgetName)
    end
  end

  -- Per-frame update. Once the current trial has a stepCount, counts its
  -- frames. Ends the episode when opts.timeoutIfIdle was set and more than
  -- kwargs.maxIdleSteps frames have passed since the last reset of
  -- _stepsSinceInteraction (fixation or response). This should speed up the
  -- early stages of training.
  function env:step(lookingAtScreen)
    -- NOTE(review): fixationCallback receives no hoverTime here, so its
    -- `hoverTime == kwargs.timeToFixateCross` check fails and the call
    -- appears to be a no-op. It is kept as-is to preserve behaviour.
    if self.currentTrial.stepCount == nil then self:fixationCallback() end

    local trial = self.currentTrial
    if trial.stepCount ~= nil then
      trial.stepCount = trial.stepCount + 1
    end

    local idleSteps = self._stepsSinceInteraction + 1
    self._stepsSinceInteraction = idleSteps
    if self._timeoutIfIdle and idleSteps > kwargs.maxIdleSteps then
      self.pac:endEpisode()
    end
  end

  -- Options forwarded to each env instance. The pattern-type flags are
  -- parsed in env:_init.
  local envOpts = {
      environment = env,
      screenSize = kwargs.screenSize,
      concentric = kwargs.concentric,
      radial = kwargs.radial,
      hyperbolic = kwargs.hyperbolic,
      spiral = kwargs.spiral,
  }
  return psychlab_factory.createLevelApi{
      env = point_and_click,
      envOpts = envOpts,
      episodeLengthSeconds = 180
  }
end

return factory
